{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding dispersion to the CHGNet pre-trained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to add a dispersion correction to the pre-trained CHGNet model. CHGNet is trained on PBE (GGA) DFT calculations, so it does not account for van der Waals (dispersion) interactions. Adding the correction after the fact may be particularly useful for those who study porous materials, such as MOFs or zeolites, and do not wish to fine-tune the pre-trained model on dispersion-corrected data.\n",
    "\n",
    "This notebook uses both the [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master) and [DFT-D4](https://dftd4.readthedocs.io/en/latest/reference/ase.html) packages to add dispersion to CHGNet. torch-dftd currently implements DFT-D2 and DFT-D3 but not the more recent DFT-D4; it is, however, GPU-accelerated, whereas the DFT-D4 package is not. The Grimme group has also released a [PyTorch implementation of DFT-D4](https://github.com/dftd4/tad-dftd4), but it does not provide an ASE-compatible calculator.\n",
    "\n",
    "To run this notebook, install CHGNet, [ASE](https://wiki.fysik.dtu.dk/ase/install.html), [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master?tab=readme-ov-file#install), and [DFT-D4](https://dftd4.readthedocs.io/en/latest/recipe/installation.html); each link points to the package's installation instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ase.build import fcc111\n",
    "from ase.calculators.mixing import SumCalculator\n",
    "from dftd4.ase import DFTD4\n",
    "from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator\n",
    "\n",
    "from chgnet.model.dynamics import CHGNetCalculator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CHGNet v0.3.0 initialized with 412,525 parameters\n",
      "CHGNet will run on cpu\n",
      "CHGNet will run on cpu\n"
     ]
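    }
   ],
   "source": [
    "# Pre-trained CHGNet model wrapped as an ASE calculator\n",
    "chgnet_calc = CHGNetCalculator()\n",
    "\n",
    "# DFT-D3 correction from torch-dftd\n",
    "d3_calc = TorchDFTD3Calculator()  # uses PBE parameters by default\n",
    "\n",
    "# DFT-D4 correction with PBE parameters\n",
    "d4_calc = DFTD4(method=\"PBE\")"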
   ]
  },
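  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GPU acceleration in torch-dftd only helps if the calculator is actually placed on a GPU. The cell below is a hedged sketch of how that could look. It assumes that `TorchDFTD3Calculator` accepts the `device` and `xc` keyword arguments and that `CHGNetCalculator` accepts `use_device`, as in the versions we are aware of; check both against your installed packages. The names `device`, `chgnet_gpu_calc`, and `d3_gpu_calc` are illustrative and are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Fall back to the CPU when no CUDA device is available\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Assumed keyword arguments (use_device, device, xc); verify them for your versions\n",
    "chgnet_gpu_calc = CHGNetCalculator(use_device=device)\n",
    "d3_gpu_calc = TorchDFTD3Calculator(device=device, xc=\"pbe\")"
   ]
  },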
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A simple example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example builds an ASE `Atoms` object for a Cu(111) surface and computes its energy with CHGNet alone, with CHGNet + DFT-D3, and with CHGNet + DFT-D4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Disp calculator: sumcalculator\n",
      "Cu4\n",
      "E without dispersion: -12.876540184020996\n",
      "E with DFT-D3 dispersion: -13.272150007989707\n",
      "E with DFT-D4 dispersion: -13.573149493981\n"
     ]
    }
   ],
   "source": [
    "# Create a 2x2x1 fcc(111) Cu slab\n",
    "atoms = fcc111(\"Cu\", (2, 2, 1), vacuum=10.0)\n",
    "atoms.set_pbc([True, True, True])\n",
    "\n",
    "# Separate copies so that each structure carries its own calculator\n",
    "atoms_disp = atoms.copy()  # CHGNet + DFT-D3\n",
    "atoms_d4 = atoms.copy()  # CHGNet + DFT-D4\n",
    "\n",
    "atoms.calc = chgnet_calc  # CHGNet only\n",
    "\n",
    "# SumCalculator adds up the results of its member calculators\n",
    "chgnet_d3 = SumCalculator([chgnet_calc, d3_calc])\n",
    "chgnet_d4 = SumCalculator([chgnet_calc, d4_calc])\n",
    "atoms_disp.calc = chgnet_d3\n",
    "atoms_d4.calc = chgnet_d4\n",
    "\n",
    "# Potential energies in eV\n",
    "e_chg = atoms.get_potential_energy()\n",
    "e_disp = atoms_disp.get_potential_energy()\n",
    "e_d4 = atoms_d4.get_potential_energy()\n",
    "\n",
    "print(f\"Disp calculator: {chgnet_d3.name}\")\n",
    "print(atoms.get_chemical_formula())\n",
    "print(f\"E without dispersion: {e_chg}\")\n",
    "print(f\"E with DFT-D3 dispersion: {e_disp}\")\n",
    "print(f\"E with DFT-D4 dispersion: {e_d4}\")"
   ]
  },
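  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `SumCalculator` adds the CHGNet and dispersion energies, the size of each correction can be recovered by subtracting the plain CHGNet energy. The cell below is a minimal sketch that reuses `e_chg`, `e_disp`, `e_d4`, and `atoms` from the previous cell; the names `n_atoms` and `corrections` are introduced here only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dispersion contribution = (CHGNet + dispersion) - (CHGNet only), in eV\n",
    "n_atoms = len(atoms)\n",
    "corrections = {\"DFT-D3\": e_disp - e_chg, \"DFT-D4\": e_d4 - e_chg}\n",
    "\n",
    "for label, delta in corrections.items():\n",
    "    print(f\"{label} correction: {delta:.4f} eV ({delta / n_atoms:.4f} eV/atom)\")"
   ]
  },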
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example below relaxes a bulk Cu cell, starting from a structure with one displaced atom and a perturbed unit cell, using the CHGNet + DFT-D4 calculator."
   ]
  },
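  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ase.build import bulk\n",
    "from ase.filters import FrechetCellFilter\n",
    "from ase.optimize import BFGS\n",
    "\n",
    "# Conventional cubic cell of fcc Cu (4 atoms)\n",
    "atoms = bulk(\"Cu\", cubic=True)\n",
    "\n",
    "# Displace the first atom by 0.1 Å along x\n",
    "atoms[0].x += 0.1\n",
    "# Add 0.1 Å to every component of the first lattice vector\n",
    "atoms.cell[0] += 0.1"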
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Cell([[3.71, 0.1, 0.1], [0.0, 3.61, 0.0], [0.0, 0.0, 3.61]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "atoms.cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Step     Time          Energy          fmax\n",
      "BFGS:    0 17:55:31      -18.448409        0.718576\n",
      "BFGS:    1 17:55:33      -18.479250        0.649700\n",
      "BFGS:    2 17:55:34      -18.602240     1352.772997\n",
      "BFGS:    3 17:55:35      -18.495526        0.666888\n",
      "BFGS:    4 17:55:36      -18.505485        0.674898\n",
      "BFGS:    5 17:55:37      -18.524050        0.707074\n",
      "BFGS:    6 17:55:39      -18.524685        0.708438\n",
      "BFGS:    7 17:55:40      -18.527744        0.722830\n",
      "BFGS:    8 17:55:41      -18.529436        0.740565\n",
      "BFGS:    9 17:55:42      -18.530648        0.755404\n",
      "BFGS:   10 17:55:43      -18.530939        0.757174\n",
      "BFGS:   11 17:55:44      -18.531477        0.756699\n",
      "BFGS:   12 17:55:46      -18.532681        0.750869\n",
      "BFGS:   13 17:55:47      -18.534685        0.741311\n",
      "BFGS:   14 17:55:48      -18.537005        0.728779\n",
      "BFGS:   15 17:55:49      -18.539452        0.706232\n",
      "BFGS:   16 17:55:50      -18.540848        0.684654\n",
      "BFGS:   17 17:55:52      -18.541594        0.680409\n",
      "BFGS:   18 17:55:53      -18.544656        0.659496\n",
      "BFGS:   19 17:55:54      -18.548455        0.625415\n",
      "BFGS:   20 17:55:55      -18.554630        0.578256\n",
      "BFGS:   21 17:55:56      -18.562180        0.520110\n",
      "BFGS:   22 17:55:58      -18.568905        0.460355\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
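   ],
   "source": [
    "atoms.calc = chgnet_d4\n",
    "\n",
    "# FrechetCellFilter lets the optimizer relax the cell together with the atomic positions\n",
    "cell_filter = FrechetCellFilter(atoms)\n",
    "opt = BFGS(cell_filter, trajectory=\"Cu.traj\")\n",
    "\n",
    "# Stop when the convergence criterion fmax (eV/Å) is met or after 100 steps\n",
    "opt.run(fmax=0.5, steps=100)"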
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To view the optimization trajectory, run the following from the command line in this folder:\n",
    "\n",
    "```bash\n",
    "ase gui Cu.traj\n",
    "```"
   ]
  }
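  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trajectory can be inspected from Python as well. The cell below is a minimal sketch that assumes `Cu.traj` was written by the optimization above and that each stored frame carries the energy computed at that step; `frames` and `energies` are illustrative names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ase.io import read\n",
    "\n",
    "# index=\":\" loads every optimization step as a list of Atoms objects\n",
    "frames = read(\"Cu.traj\", index=\":\")\n",
    "\n",
    "# Energies stored with each frame (eV); assumes they were saved by the optimizer\n",
    "energies = [frame.get_potential_energy() for frame in frames]\n",
    "\n",
    "print(f\"Number of frames: {len(frames)}\")\n",
    "print(f\"Energy change over the run: {energies[-1] - energies[0]:.4f} eV\")\n",
    "print(f\"Final cell lengths and angles: {frames[-1].cell.cellpar()}\")"
   ]
  }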
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "htvs",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
